import numpy as np
from scipy.optimize import minimize
import pandas as pd
from IPython.display import clear_output
import copy

from FBSM_library import Objective_Functional
from InitProblem_library import RKM_Dyn_System



def Objective_Functional_Direct_Method(u_t_vector, y_0, Arguments):

    """
    Evaluate the target functional for a control given in flattened (vector) form.

    The flat vector u_t_vector is cut into consecutive blocks of length m*M, one
    block per grid interval, and every block is reshaped into an m x M control
    matrix. The system is then integrated from y_0 under this control with
    RKM_Dyn_System, and the resulting trajectory is scored with Objective_Functional.

    Input:
    u_t_vector  - the control function u(tau) in vector representation
                  (all m x M control matrices flattened and concatenated)
    y_0         - the starting point of the system (see formula (4) from the main manuscript)
    Arguments   - hyperparameters; the keys 'm', 'M' and 'Grid' are read here and the
                  whole dict is forwarded to RKM_Dyn_System and Objective_Functional

    Output:
    The sum of the two values returned by Objective_Functional
    (the integral part and the terminal part of the functional).
    """

    rows = Arguments['m']
    cols = Arguments['M']
    n_intervals = len(Arguments['Grid']) - 1
    chunk = rows * cols

    # One m x M control matrix per grid interval, taken from consecutive slices
    controls = []
    for k in range(n_intervals):
        start = k * chunk
        controls.append(u_t_vector[start:start + chunk].reshape(rows, cols))

    # The trailing 1 is the iteration-counter argument passed to RKM_Dyn_System
    trajectory = RKM_Dyn_System(y_0, controls, Arguments, 1)

    integral_part, terminal_part = Objective_Functional(trajectory, Arguments)

    return integral_part + terminal_part




def Control_Eq_Constraint(u_t_vector, index, Arguments):

    """
    Equality constraint for the direct-method optimization problem.

    It requires the entries of the index-th control matrix to sum to
    1 - n_{ [ M ] } (in the current notation, 1 - n), i.e. the total fraction
    of bots on that grid interval equals 1 - n. scipy.optimize.minimize treats an
    'eq' constraint as satisfied when the returned value is zero.

    Input:
    u_t_vector - the control function in vector representation (all m x M control
                 matrices flattened and concatenated)
    index      - the index of the control matrix (in matrix representation) on which
                 the constraint is imposed
    Arguments  - hyperparameters; the keys 'm', 'M' and 'n' are read here

    Output:
    The sum of the entries of the index-th block of length m*M, minus (1 - n).
    """

    m = Arguments['m']
    M = Arguments['M']

    n = Arguments['n']  # the fraction of ordinary agents in the system

    #---------------------------------------------------------------------------

    # The index-th control matrix occupies a contiguous block of length m*M
    block_size = m * M
    block = u_t_vector[index * block_size:(index + 1) * block_size]

    return block.sum() - (1 - n)



def DM(y_0, u_t_0, Arguments, *, method=None, options=None):

    """
    Solve the control problem by a direct method.

    The target functional of the control problem (see formula (11) in the main
    manuscript) is minimized with respect to the control function u_t over the
    set of admissible control functions (see formulas (9), (10) in the main
    manuscript). On every grid interval the control is an m x M matrix with
    non-negative entries that sum to 1 - n (see Control_Eq_Constraint). All
    matrices are stacked into a single vector, because scipy.optimize.minimize
    works with 1-D inputs.

    Input:
    y_0       - the starting point of the system (see formula (4) from the main manuscript)
    u_t_0     - the initial guess for the control function: a sequence with at least
                len(Grid)-1 entries, each an array (or nested list) reshapeable to (m, M)
    Arguments - hyperparameters; the keys 'm', 'M' and 'Grid' are read here and the
                whole dict is forwarded to the objective and the constraints
    method    - (optional, keyword-only) solver name forwarded to scipy.optimize.minimize;
                None keeps scipy's default choice (SLSQP for a problem with constraints)
    options   - (optional, keyword-only) dict of solver options forwarded to
                scipy.optimize.minimize

    Output:
    u_t       - the best control function found: a list of len(Grid)-1 arrays of shape (m, M)

    Raises:
    ValueError if Grid has fewer than two points (there is no interval to control).

    Warns:
    RuntimeWarning if the optimizer reports failure; the last iterate is still returned.
    """

    # Local import keeps this block self-contained; used only to report non-convergence
    import warnings

    m = Arguments['m']
    M = Arguments['M']

    # The grid: one m x M control matrix per grid interval

    Grid = Arguments['Grid']

    n_intervals = len(Grid) - 1
    block_size = m * M

    if n_intervals < 1:
        raise ValueError("DM: 'Grid' must contain at least two points")

    #---------------------------------------------------------------------------

    # Flatten the initial guess into one vector: each m x M matrix becomes a
    # block of length m*M, and the blocks are concatenated in interval order

    u_t_0_vector = np.hstack([np.asarray(u_t_0[i]).reshape(block_size)
                              for i in range(n_intervals)])

    #---------------------------------------------------------------------------

    # Bound-type constraints: every control entry is non-negative

    Bounds = [(0, None)] * len(u_t_0_vector)

    # Equality constraints: on every interval the entries of the control
    # matrix sum to 1 - n (see Control_Eq_Constraint)

    Constraints = [{'type': 'eq',
                    'fun': Control_Eq_Constraint,
                    'args': (i, Arguments, )} for i in range(n_intervals)]

    #---------------------------------------------------------------------------

    # Run minimization

    res = minimize(fun=Objective_Functional_Direct_Method,
                   x0=u_t_0_vector,
                   args=(y_0, Arguments, ),
                   method=method,
                   bounds=Bounds,
                   constraints=Constraints,
                   options=options)

    if not res.success:
        # Do not fail silently: the returned control is only the last iterate
        warnings.warn(f"DM: the optimizer did not converge ({res.message})",
                      RuntimeWarning, stacklevel=2)

    ArgMin = res.x  # the minimizer of the objective functional

    # Transform the output back into matrix form: split into blocks of
    # length m*M and reshape each one into an m x M matrix

    u_t = [ArgMin[i*block_size:(i+1)*block_size].reshape(m, M)
           for i in range(n_intervals)]

    return u_t






